feat(pt_expt): add dp freeze support and dp test tests for .pte models #5302
Conversation
Add freeze() function to pt_expt backend that loads a .pt checkpoint, reconstructs the model, serializes it, and exports to .pte via deserialize_to_file. Wire the freeze command in the main() CLI dispatcher. Add separate test files for dp freeze (test_dp_freeze.py) and dp test (test_dp_test.py) verifying the full freeze-then-test pipeline works end-to-end with .pte models.
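As a rough illustration of that pipeline, the control flow can be sketched as below. The helpers get_model, ModelWrapper, and deserialize_to_file are hypothetical stand-in stubs here, not the real deepmd implementations, so only the shape of the flow is meaningful:

```python
# Control-flow sketch of freeze(): read checkpoint metadata, rebuild the
# model, normalize the output name, export. All three helpers are stubs.
import tempfile
from pathlib import Path

def get_model(model_params):             # stub: the real one builds a torch model
    return {"params": model_params}

class ModelWrapper:                      # stub: the real one wraps an nn.Module
    def __init__(self, model):
        self.model = model

def deserialize_to_file(path, wrapped):  # stub: the real one writes a .pte export
    Path(path).write_text("exported")

def freeze(state_dict: dict, output: str) -> str:
    model_params = state_dict["_extra_state"]["model_params"]
    wrapped = ModelWrapper(get_model(model_params))
    if not output.endswith(".pte"):
        output = str(Path(output).with_suffix(".pte"))
    deserialize_to_file(output, wrapped)
    return output

with tempfile.TemporaryDirectory() as d:
    out = freeze({"_extra_state": {"model_params": {}}},
                 str(Path(d) / "frozen"))

print(out.endswith("frozen.pte"))  # True
```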
💡 Codex Review
Here are some automated review suggestions for this pull request.
Reviewed commit: 35745c127e
📝 Walkthrough
Adds a new "freeze" CLI command and a freeze() function to the pt_expt backend.
Sequence Diagram
sequenceDiagram
participant CLI as main()
participant Freeze as freeze()
participant ModelBuilder as get_model()/ModelWrapper
participant Storage as Filesystem
CLI->>Freeze: freeze(checkpoint_path, output_file, head)
Freeze->>Freeze: resolve checkpoint path (dir → model.ckpt.pt)
Freeze->>Storage: read checkpoint (state_dict, _extra_state)
Freeze->>ModelBuilder: get_model() → instantiate model
Freeze->>ModelBuilder: wrap model, load state_dict
ModelBuilder->>Freeze: wrapped model ready
Freeze->>Storage: deserialize_to_file(wrapped_model, output.pte)
Storage->>CLI: saved path / success log
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
🚥 Pre-merge checks: ✅ 3 passed
Actionable comments posted: 1
🧹 Nitpick comments (2)
source/tests/pt_expt/test_dp_test.py (1)
27-45: Consider extracting shared test fixtures. The model_se_e2_a configuration and checkpoint creation pattern are duplicated in test_dp_freeze.py. Consider extracting these into a shared conftest or helper module to reduce duplication.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@source/tests/pt_expt/test_dp_test.py` around lines 27 - 45, The model configuration dictionary model_se_e2_a and the checkpoint creation logic duplicated between test_dp_test.py and test_dp_freeze.py should be extracted into a shared pytest fixture or helper function (e.g., in conftest.py or a test_helpers module); create a fixture named model_se_e2_a that returns the dict and a helper fixture/function (e.g., make_checkpoint or checkpoint_fixture) that encapsulates the checkpoint creation pattern, then update both tests to accept those fixtures instead of redefining the dict/checkpoint code so duplication is removed and maintenance is centralized.
deepmd/pt_expt/entrypoints/main.py (1)
256-257: Minor: .pt2 suffix check might be undocumented. The code accepts both .pte and .pt2 suffixes, but the docstring and default only mention .pte. Consider documenting .pt2 if it's intentionally supported, or removing it if not needed.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt_expt/entrypoints/main.py` around lines 256 - 257, The FLAGS.output handling accepts both ".pte" and ".pt2" but only ".pte" is documented; decide whether ".pt2" is intentional and then update code accordingly: if intended, add ".pt2" to the module/docstring and the FLAGS.output help/default text (where FLAGS is defined) and update any docs/tests to mention ".pt2"; otherwise remove ".pt2" from the tuple in the conditional so FLAGS.output only normalizes to ".pte". Ensure changes reference FLAGS.output and the suffix check in main.py.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@deepmd/pt_expt/entrypoints/main.py`:
- Around line 250-253: The code reading the checkpoint name using
(checkpoint_path / "checkpoint").read_text() assigns latest_ckpt_file with
possible trailing newline/whitespace which breaks FLAGS.model path construction;
update the read to strip whitespace (e.g., call .strip() on the result) before
using checkpoint_path.joinpath and set FLAGS.model =
str(checkpoint_path.joinpath(latest_ckpt_file.strip())), ensuring you reference
FLAGS.checkpoint_folder, checkpoint_path, latest_ckpt_file and FLAGS.model when
making the change.
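A small runnable demonstration of this finding, using a throwaway directory (the checkpoint file name is illustrative): read_text() keeps the trailing newline, so the recorded name must be stripped before joining it onto the checkpoint path.

```python
# Demo of the whitespace bug: the "checkpoint" file ends with a newline,
# and joining the raw text would produce a path name containing "\n".
import tempfile
from pathlib import Path

def resolve_latest(ckpt_dir: Path) -> Path:
    # Suggested fix: strip whitespace before building the model path.
    name = (ckpt_dir / "checkpoint").read_text().strip()
    return ckpt_dir / name

with tempfile.TemporaryDirectory() as d:
    ckpt_dir = Path(d)
    (ckpt_dir / "checkpoint").write_text("model.ckpt-100.pt\n")
    resolved = resolve_latest(ckpt_dir)

print(resolved.name)  # model.ckpt-100.pt
```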
---
Nitpick comments:
In `@deepmd/pt_expt/entrypoints/main.py`:
- Around line 256-257: The FLAGS.output handling accepts both ".pte" and ".pt2"
but only ".pte" is documented; decide whether ".pt2" is intentional and then
update code accordingly: if intended, add ".pt2" to the module/docstring and the
FLAGS.output help/default text (where FLAGS is defined) and update any
docs/tests to mention ".pt2"; otherwise remove ".pt2" from the tuple in the
conditional so FLAGS.output only normalizes to ".pte". Ensure changes reference
FLAGS.output and the suffix check in main.py.
In `@source/tests/pt_expt/test_dp_test.py`:
- Around line 27-45: The model configuration dictionary model_se_e2_a and the
checkpoint creation logic duplicated between test_dp_test.py and
test_dp_freeze.py should be extracted into a shared pytest fixture or helper
function (e.g., in conftest.py or a test_helpers module); create a fixture named
model_se_e2_a that returns the dict and a helper fixture/function (e.g.,
make_checkpoint or checkpoint_fixture) that encapsulates the checkpoint creation
pattern, then update both tests to accept those fixtures instead of redefining
the dict/checkpoint code so duplication is removed and maintenance is
centralized.
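A minimal sketch of the proposed refactor, assuming pytest; the config keys below are illustrative placeholders, not the real se_e2_a schema used in the tests:

```python
# source/tests/pt_expt/conftest.py (hypothetical sketch)
import pytest

def make_model_se_e2_a() -> dict:
    # Illustrative config; the real dict lives in the duplicated tests.
    return {
        "type_map": ["O", "H"],
        "descriptor": {"type": "se_e2_a", "rcut": 6.0},
        "fitting_net": {"neuron": [24, 24, 24]},
    }

@pytest.fixture
def model_se_e2_a() -> dict:
    # Both test_dp_freeze.py and test_dp_test.py would request this fixture
    # instead of redefining the dict.
    return make_model_se_e2_a()
```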
ℹ️ Review info
⚙️ Run configuration
Configuration used: Repository UI
Review profile: CHILL
Plan: Pro
Run ID: e8b4749c-4b84-4f85-a10c-9f671ace384a
📒 Files selected for processing (3)
deepmd/pt_expt/entrypoints/main.py
source/tests/pt_expt/test_dp_freeze.py
source/tests/pt_expt/test_dp_test.py
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## master #5302 +/- ##
==========================================
- Coverage 82.32% 82.30% -0.02%
==========================================
Files 768 775 +7
Lines 77098 77664 +566
Branches 3659 3675 +16
==========================================
+ Hits 63469 63924 +455
- Misses 12458 12567 +109
- Partials 1171 1173 +2
☔ View full report in Codecov by Sentry.
pt_expt training saves checkpoints as model.ckpt-{step}.pt with a
model.ckpt.pt symlink, not a "checkpoint" text file. The previous
code was copied from the pt backend which uses a different format.
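Under that layout, directory resolution can be sketched like this (the helper name is hypothetical; the file names follow the pt_expt convention described above):

```python
# Resolve a checkpoint argument under the pt_expt layout: training writes
# model.ckpt-{step}.pt plus a model.ckpt.pt symlink, so a directory
# resolves to the symlink and a file path is used as-is.
import tempfile
from pathlib import Path

def resolve_pt_expt_checkpoint(path: str) -> Path:
    p = Path(path)
    return p / "model.ckpt.pt" if p.is_dir() else p

with tempfile.TemporaryDirectory() as d:
    ckpt_dir = Path(d)
    step_file = ckpt_dir / "model.ckpt-100.pt"
    step_file.write_bytes(b"")
    (ckpt_dir / "model.ckpt.pt").symlink_to(step_file.name)
    resolved = resolve_pt_expt_checkpoint(d)
    target = resolved.resolve().name  # follow the symlink while it exists

print(resolved.name, target)  # model.ckpt.pt model.ckpt-100.pt
```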
Actionable comments posted: 2
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@deepmd/pt_expt/entrypoints/main.py`:
- Around line 194-198: The checkpoint loading code assumes a nested schema and
directly accesses state_dict["_extra_state"]["model_params"], which can raise
uninformative KeyError; update the block after torch.load(model, ...) to
explicitly validate the schema: ensure state_dict is a dict (after handling the
optional "model" wrapper), that "_extra_state" is present and is a dict, and
that "model_params" exists inside it, and if any check fails raise a clear
ValueError mentioning the expected keys and the actual top-level keys (include
the model identifier variable name), so callers get an actionable error instead
of a raw KeyError when accessing model_params.
- Around line 261-263: The code currently assigns FLAGS.model =
FLAGS.checkpoint_folder without validating that the path exists, deferring
errors to torch.load; update the else branch that sets FLAGS.model (using
FLAGS.checkpoint_folder) to immediately check the filesystem: if the path is an
existing directory or an existing file accept it, otherwise raise a clear CLI
error (e.g., call parser.error or sys.exit with a descriptive message) so the
user fails fast before reaching torch.load; reference FLAGS.model and
FLAGS.checkpoint_folder in the check and ensure the error message mentions the
invalid checkpoint path.
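The fail-fast check reduces to something like the following sketch (validate_checkpoint_path and the error text are hypothetical; the real code would hook into the argparse parser via parser.error):

```python
# Validate the checkpoint path up front so the user gets a clear CLI error
# instead of an opaque failure from torch.load later on.
from pathlib import Path

def validate_checkpoint_path(checkpoint_folder: str) -> str:
    p = Path(checkpoint_folder)
    if not (p.is_dir() or p.is_file()):
        # In the real CLI this would be parser.error(...) or sys.exit(...).
        raise SystemExit(
            f"Invalid checkpoint path: '{checkpoint_folder}' does not exist."
        )
    return str(p)

try:
    validate_checkpoint_path("/no/such/checkpoint")
except SystemExit as e:
    err = str(e)

ok = validate_checkpoint_path(".")  # existing directory is accepted
print(err)
```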
📒 Files selected for processing (1)
deepmd/pt_expt/entrypoints/main.py
Actionable comments posted: 1
♻️ Duplicate comments (1)
deepmd/pt_expt/entrypoints/main.py (1)
194-199: ⚠️ Potential issue | 🟡 Minor: Validate checkpoint root type before nested access.
state_dict.get("_extra_state") assumes a dict-like object. If a malformed/incompatible checkpoint is loaded, this path raises an opaque AttributeError instead of a clear CLI error.
Proposed fix
 state_dict = torch.load(model, map_location=DEVICE, weights_only=True)
 if "model" in state_dict:
     state_dict = state_dict["model"]
+if not isinstance(state_dict, dict):
+    raise ValueError(
+        f"Unsupported checkpoint format at '{model}': "
+        f"expected dict-like state_dict, got {type(state_dict).__name__}."
+    )
 extra_state = state_dict.get("_extra_state")
 if not isinstance(extra_state, dict) or "model_params" not in extra_state:
     raise ValueError(
         f"Unsupported checkpoint format at '{model}': missing "
         "'_extra_state.model_params' in model state dict."
     )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@deepmd/pt_expt/entrypoints/main.py` around lines 194 - 199, The code assumes state_dict is a mapping and does state_dict.get("_extra_state") which can raise AttributeError for malformed checkpoints; update the loading logic around torch.load(..., weights_only=True) and the subsequent state_dict handling (variable state_dict and extra_state) to first verify state_dict is a dict-like object (e.g., isinstance(state_dict, dict)) and only then attempt to read "_extra_state" and "model_params"; if the type check fails or "_extra_state" / "model_params" are missing, raise a clear CLI-friendly error/exception with context (including the checkpoint identifier) instead of letting AttributeError propagate.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@deepmd/pt_expt/entrypoints/main.py`:
- Around line 206-209: The current guard only raises NotImplementedError when
both head is provided and "model_dict" is in model_params, but we should block
any multi-task checkpoint regardless of head; update the check in main.py (the
block containing head and model_params) to raise NotImplementedError whenever
"model_dict" is present in model_params (remove the head requirement) so
multi-task checkpoints are unconditionally rejected by the pt_expt backend.
---
Duplicate comments:
In `@deepmd/pt_expt/entrypoints/main.py`:
- Around line 194-199: The code assumes state_dict is a mapping and does
state_dict.get("_extra_state") which can raise AttributeError for malformed
checkpoints; update the loading logic around torch.load(..., weights_only=True)
and the subsequent state_dict handling (variable state_dict and extra_state) to
first verify state_dict is a dict-like object (e.g., isinstance(state_dict,
dict)) and only then attempt to read "_extra_state" and "model_params"; if the
type check fails or "_extra_state" / "model_params" are missing, raise a clear
CLI-friendly error/exception with context (including the checkpoint identifier)
instead of letting AttributeError propagate.
📒 Files selected for processing (1)
deepmd/pt_expt/entrypoints/main.py
A multi-task checkpoint without --head would silently fall through and produce wrong results. Remove the head check so any multi-task checkpoint raises NotImplementedError regardless of --head flag.
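The resulting guard amounts to this sketch (the function name is hypothetical; model_params stands in for the loaded checkpoint metadata):

```python
# Unconditionally reject multi-task checkpoints: the presence of
# "model_dict" marks a multi-task model regardless of any --head flag.
def reject_multi_task(model_params: dict) -> None:
    if "model_dict" in model_params:
        raise NotImplementedError(
            "Multi-task checkpoints are not supported by the pt_expt backend."
        )

try:
    reject_multi_task({"model_dict": {"branch_a": {}}})
except NotImplementedError as e:
    msg = str(e)

reject_multi_task({"descriptor": {}})  # single-task checkpoints pass through
print(msg)
```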
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@deepmd/pt_expt/entrypoints/main.py`:
- Around line 275-276: The output normalization should only accept the .pte
extension: change the condition that checks FLAGS.output so it only allows
".pte" (i.e., replace the endswith((".pte", ".pt2")) check with
endswith(".pte")), and if it doesn't, set FLAGS.output =
str(Path(FLAGS.output).with_suffix(".pte")); update the logic around
FLAGS.output and Path.with_suffix to drop support for ".pt2" so downstream
loaders (serialize_from_file, DeepEval, torch.export.load) always get a .pte
file.
📒 Files selected for processing (1)
deepmd/pt_expt/entrypoints/main.py
if not FLAGS.output.endswith((".pte", ".pt2")):
    FLAGS.output = str(Path(FLAGS.output).with_suffix(".pte"))
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Verify whether `.pt2` is supported wherever `.pte` is consumed.
# Expected: extension routing/loading code should explicitly accept both,
# otherwise prefer normalizing freeze output to `.pte` only.
rg -nP --type=py -C3 '\.pte\b|\.pt2\b'
rg -nP --type=py -C4 'freeze|test|infer|load.*model|endswith\(\(.*\.pte|\.pt2'
Repository: deepmodeling/deepmd-kit
Length of output: 50381
🏁 Script executed:
# Search for code that loads .pte or .pt2 files in inference/test paths
rg -n --type=py '\.pte|\.pt2' --max-count=100 deepmd/pt_expt/infer/ deepmd/pt_expt/utils/serialization.py deepmd/backend/pt_expt.py
# Check the DeepEval class that loads models
rg -n --type=py -B2 -A5 'class DeepEval|def.*__init__' deepmd/pt_expt/infer/deep_eval.py | head -50
# Search for any explicit .pt2 handling
rg -n --type=py 'pt2' deepmd/
Repository: deepmodeling/deepmd-kit
Length of output: 1948
🏁 Script executed:
# Check if .pt2 is used anywhere else in tests or loading code
rg -n 'pt2' source/tests/ deepmd/pt_expt/
# Check if there's any model loading with file extension handling
rg -n --type=py 'load.*pte|pte.*load|suffix' deepmd/pt_expt/infer/ deepmd/pt_expt/utils/
# Look for file operations that might accept both extensions
rg -n --type=py 'endswith|suffix|extension' deepmd/pt_expt/ -A2 -B2 | head -80
Repository: deepmodeling/deepmd-kit
Length of output: 4004
🏁 Script executed:
# Final verification: check if pt2 is referenced in any way outside main.py
rg -r 'pt2' --type=py . | grep -v 'test_hybrid_muon\|test_adamuon\|descrpt2r4'
# Verify the backend suffixes definition once more
cat -n deepmd/backend/pt_expt.py | sed -n '40,50p'
Length of output: 50382
🏁 Script executed:
# Get the backend file content to see the suffix definition
cat -n deepmd/backend/pt_expt.py | sed -n '40,50p'
# Search more specifically for code that reads .pte files
rg -n --type=py 'torch.export.load|load.*pte' deepmd/
Repository: deepmodeling/deepmd-kit
Length of output: 654
Tighten output normalization to accept only .pte extension.
Lines 275-276 preserve both .pte and .pt2 suffixes, but the backend and all loaders (serialize_from_file, DeepEval, torch.export.load) only support .pte. If users provide .pt2 output, the resulting files cannot be loaded downstream, creating broken artifacts.
Suggested fix
- if not FLAGS.output.endswith((".pte", ".pt2")):
+ if not FLAGS.output.endswith(".pte"):
      FLAGS.output = str(Path(FLAGS.output).with_suffix(".pte"))
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@deepmd/pt_expt/entrypoints/main.py` around lines 275 - 276, The output
normalization should only accept the .pte extension: change the condition that
checks FLAGS.output so it only allows ".pte" (i.e., replace the
endswith((".pte", ".pt2")) check with endswith(".pte")), and if it doesn't, set
FLAGS.output = str(Path(FLAGS.output).with_suffix(".pte")); update the logic
around FLAGS.output and Path.with_suffix to drop support for ".pt2" so
downstream loaders (serialize_from_file, DeepEval, torch.export.load) always get
a .pte file.
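The tightened normalization amounts to the following (a hypothetical standalone helper mirroring the suggested diff, not the in-tree code):

```python
# Only .pte is accepted; any other suffix, including .pt2, is rewritten
# so downstream loaders always receive a .pte file.
from pathlib import Path

def normalize_output(output: str) -> str:
    if not output.endswith(".pte"):
        output = str(Path(output).with_suffix(".pte"))
    return output

print(normalize_output("model.pt2"))   # model.pte
print(normalize_output("frozen"))      # frozen.pte
print(normalize_output("model.pte"))   # model.pte
```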
Summary
- dp freeze support for the pt_expt backend, enabling checkpoint .pt → exported .pte conversion
- Tests covering dp freeze and dp test with .pte models
Background
The pt_expt backend can export models to .pte via deserialize_to_file(), and dp test can already load .pte models through the registered DeepEval. However, dp freeze was not wired up: calling dp freeze -b pt-expt hit RuntimeError: Unsupported command 'freeze'.
Changes
deepmd/pt_expt/entrypoints/main.py
- freeze() function: loads .pt checkpoint → reconstructs model via get_model + ModelWrapper → serializes → exports to .pte via deserialize_to_file
- freeze command in main() dispatcher with checkpoint directory resolution and .pte default suffix
source/tests/pt_expt/test_dp_freeze.py (new)
- test_freeze_pte: verify .pte file is created from checkpoint
- test_freeze_main_dispatcher: test main() CLI dispatcher with freeze command
- test_freeze_default_suffix: verify non-.pte output suffix is corrected to .pte
source/tests/pt_expt/test_dp_test.py (new)
- test_dp_test_system: test dp test with -s system path, verify .e.out, .f.out, .v.out outputs
- test_dp_test_input_json: test dp test with --valid-data JSON input
Test plan
- python -m pytest source/tests/pt_expt/test_dp_freeze.py -v (3 passed)
- python -m pytest source/tests/pt_expt/test_dp_test.py -v (2 passed)